![]() | ![]() | | ![]() |
Ky tutorial demonstron se si të rregulloni modelin RecurrentGemma 2B Instruct për një detyrë përkthimi anglisht-francez duke përdorur bibliotekën recurrentgemma
të Google DeepMind , JAX (një bibliotekë kompjuterike numerike me performancë të lartë), Flax (bibliotekë e rrjetit nervor të bazuar në JAX), Chex (një bibliotekë për shkrimin e opcioneve të besueshme të JAX-së) Biblioteka e përpunimit dhe optimizimit të gradientit), dhe grupi i të dhënave MTNT (Përkthimi makine i tekstit të zhurmshëm) . Megjithëse Liri nuk përdoret drejtpërdrejt në këtë fletore, Liri u përdor për të krijuar Gemma.
Biblioteka recurrentgemma
u shkrua me JAX, Flax, Orbax (një bibliotekë e bazuar në JAX për shërbimet e trajnimit si pika e kontrollit) dhe SentencePiece (një bibliotekë tokenizues/detokenizues).
Kjo fletore mund të funksionojë në Google Colab me GPU T4 (shkoni te Edit > Cilësimet e Notebook > Nën Përshpejtuesin Hardware zgjidhni T4 GPU ).
Konfigurimi
Seksionet e mëposhtme shpjegojnë hapat për përgatitjen e një fletoreje për të përdorur një model RecurrentGemma, duke përfshirë aksesin në model, marrjen e një çelësi API dhe konfigurimin e kohës së funksionimit të fletores.
Konfiguro qasjen në Kaggle për Gemma
Për të përfunduar këtë tutorial, së pari duhet të ndiqni udhëzimet e konfigurimit të ngjashme me konfigurimin e Gemma me disa përjashtime:
- Merrni akses në RecurrentGemma (në vend të Gemma) në kaggle.com .
- Zgjidhni një kohë ekzekutimi Colab me burime të mjaftueshme për të ekzekutuar modelin RecurrentGemma.
- Gjeneroni dhe konfiguroni një emër përdoruesi dhe çelës API të Kaggle.
Pasi të keni përfunduar konfigurimin e RecurrentGemma, kaloni në seksionin tjetër, ku do të vendosni variablat e mjedisit për mjedisin tuaj Colab.
Vendosni variablat e mjedisit
Cakto variablat e mjedisit për KAGGLE_USERNAME
dhe KAGGLE_KEY
. Kur kërkohet me "Grant qasje?" mesazhe, pranoni të siguroni akses sekret.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Instaloni bibliotekën recurrentgemma
Përshpejtimi i harduerit falas Colab aktualisht është i pamjaftueshëm për të ekzekutuar këtë fletore. Nëse jeni duke përdorur Colab Pay As You Go ose Colab Pro , klikoni në Edit > Cilësimet e Notebook > Zgjidhni A100 GPU > Save për të aktivizuar përshpejtimin e harduerit.
Më pas, duhet të instaloni bibliotekën e Google DeepMind recurrentgemma
nga github.com/google-deepmind/recurrentgemma
. Nëse merrni një gabim në lidhje me "zgjidhësin e varësisë së pip", zakonisht mund ta shpërfillni atë.
pip install -q git+https://github.com/google-deepmind/recurrentgemma.git
Importoni biblioteka
Ky fletore përdor Flax (për rrjetet nervore), JAX thelbësor, SentencePiece (për tokenizimin), Chex (një bibliotekë shërbimesh për shkrimin e kodit të besueshëm JAX), Optax (bibliotekën e përpunimit dhe optimizimit të gradientit) dhe grupet e të dhënave TensorFlow.
import pathlib
from typing import Any, Mapping, Iterator
import enum
import functools
import chex
import jax
import jax.numpy as jnp
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
import sentencepiece as spm
from recurrentgemma import jax as recurrentgemma
Ngarkoni modelin RecurrentGemma
- Ngarkoni modelin RecurrentGemma me
kagglehub.model_download
, i cili merr tre argumente:
-
handle
: Doreza e modelit nga Kaggle -
path
: (varg opsional) Shtegu lokal -
force_download
: (Boolean opsionale) Detyron të rishkarkojë modelin
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:50<00:00, 81.5MB/s] Extracting model files...
print('RECURRENTGEMMA_VARIANT:', RECURRENTGEMMA_VARIANT)
RECURRENTGEMMA_VARIANT: 2b-it
- Kontrolloni vendndodhjen e peshave të modelit dhe shënuesit, më pas vendosni variablat e rrugës. Drejtoria e tokenizuesit do të jetë në drejtorinë kryesore ku keni shkarkuar modelin, ndërsa peshat e modelit do të jenë në një nën-direktori. Për shembull:
- Skedari
tokenizer.model
do të jetë në/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1
). - Pika e kontrollit të modelit do të jetë në
/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it
).
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
Ngarko dhe përgatit grupin e të dhënave MTNT dhe tokenizuesin Gemma
Ju do të përdorni grupin e të dhënave MTNT (Chine Translation of Noisy Text) , i cili disponohet nga TensorFlow Datasets .
Shkarkoni pjesën e të dhënave anglisht në frëngjisht të grupit të të dhënave MTNT dhe më pas mostroni dy shembuj. Çdo mostër në grupin e të dhënave përmban dy hyrje: src
: fjalia origjinale në anglisht; dhe dst
: përkthimi përkatës frëngjisht.
ds = tfds.load("mtnt/en-fr", split="train")
ds = ds.take(2)
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
Downloading and preparing dataset 35.08 MiB (download: 35.08 MiB, generated: 11.33 MiB, total: 46.41 MiB) to /root/tensorflow_datasets/mtnt/en-fr/1.0.0... Dl Completed...: 0 url [00:00, ? url/s] Dl Size...: 0 MiB [00:00, ? MiB/s] Extraction completed...: 0 file [00:00, ? file/s] Generating splits...: 0%| | 0/3 [00:00<?, ? splits/s] Generating train examples...: 0%| | 0/35692 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incompleteJLH33K/mtnt-train.tfrecord*...: 0%| … Generating test examples...: 0%| | 0/1020 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incompleteJLH33K/mtnt-test.tfrecord*...: 0%| |… Generating valid examples...: 0%| | 0/811 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incompleteJLH33K/mtnt-valid.tfrecord*...: 0%| … Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data. Example 0: dst: b'Le groupe de " toutes les \xc3\xa9toiles potentielles de la conf\xc3\xa9rence de l\'Est mais qui ne s\'en sortent pas dans le groupe de l\'Ouest ".' src: b'The group of \xe2\x80\x9ceastern conference potential all stars but not making it in the West\xe2\x80\x9d group.' Example 1: dst: b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?" src: b'Is Kameron a Little Salty About Her Lack of Air Time?'
Ngarko tokenizuesin Gemma, i ndërtuar duke përdorur sentencepiece.SentencePieceProcessor
:
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
Personalizojeni SentencePieceProcessor
in për detyrën e përkthimit nga anglishtja në frëngjisht. Meqenëse do të rregulloni mirë pjesën angleze të modelit RecurrentGemma (Griffin), duhet të bëni disa rregullime, si p.sh.
Prefiksi i hyrjes : Shtimi i një parashtese të përbashkët për çdo hyrje sinjalizon detyrën e përkthimit. Për shembull, mund të përdorni një kërkesë me një parashtesë si
Translate this into French: [INPUT_SENTENCE]
.Prapashtesa e fillimit të përkthimit : Shtimi i një prapashtese në fund të çdo prompt udhëzon modelin Gemma saktësisht se kur të fillojë procesi i përkthimit. Një linjë e re duhet të bëjë punën.
Shenjat e modelit të gjuhës : Modelet RecurrentGemma (Griffin) presin një shenjë "fillimi i sekuencës" në fillim të çdo sekuence. Në mënyrë të ngjashme, ju duhet të shtoni një shenjë "fundi i sekuencës" në fund të çdo shembulli trajnimi.
Ndërtoni një mbështjellës të personalizuar rreth SentencePieceProcessor
si më poshtë:
class GriffinTokenizer:
"""A custom wrapper around a SentencePieceProcessor."""
def __init__(self, spm_processor: spm.SentencePieceProcessor):
self._spm_processor = spm_processor
@property
def pad_id(self) -> int:
"""Fast access to the pad ID."""
return self._spm_processor.pad_id()
def tokenize(
self,
example: str | bytes,
prefix: str = '',
suffix: str = '',
add_eos: bool = True,
) -> jax.Array:
"""
A tokenization function.
Args:
example: Input string to tokenize.
prefix: Prefix to add to the input string.
suffix: Suffix to add to the input string.
add_eos: If True, add an end of sentence token at the end of the output
sequence.
Returns:
Tokens corresponding to the input string.
"""
int_list = [self._spm_processor.bos_id()]
int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
if add_eos:
int_list.append(self._spm_processor.eos_id())
return jnp.array(int_list, dtype=jnp.int32)
def tokenize_tf_op(
self,
str_tensor: tf.Tensor,
prefix: str = '',
suffix: str = '',
add_eos: bool = True,
) -> tf.Tensor:
"""A TensforFlow operator for the `tokenize` function."""
encoded = tf.numpy_function(
self.tokenize,
[str_tensor, prefix, suffix, add_eos],
tf.int32)
encoded.set_shape([None])
return encoded
def to_string(self, tokens: jax.Array) -> str:
"""Convert an array of tokens to a string."""
return self._spm_processor.EncodeIds(tokens.tolist())
Provojeni duke instancuar GriffinTokenizer
tuaj të ri të personalizuar dhe më pas duke e aplikuar atë në një mostër të vogël të grupit të të dhënave MTNT:
def tokenize_source(tokenizer, example: tf.Tensor):
return tokenizer.tokenize_tf_op(
example,
prefix='Translate this into French:\n',
suffix='\n',
add_eos=False
)
def tokenize_destination(tokenizer, example: tf.Tensor):
return tokenizer.tokenize_tf_op(example, add_eos=True)
tokenizer = GriffinTokenizer(vocab)
ds = tfds.load("mtnt/en-fr",split="train")
ds = ds.take(2)
ds = ds.map(lambda x: {
'src': tokenize_source(tokenizer, x['src']),
'dst': tokenize_destination(tokenizer, x['dst'])
})
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
Example 0: src: [ 2 49688 736 1280 6987 235292 108 651 2778 576 1080 104745 11982 5736 832 8995 901 780 3547 665 575 573 4589 235369 2778 235265 108] dst: [ 2 2025 29653 581 664 16298 1437 55563 41435 7840 581 683 111452 581 533 235303 9776 4108 2459 679 485 235303 479 6728 579 1806 2499 709 29653 581 533 235303 101323 16054 1] Example 1: src: [ 2 49688 736 1280 6987 235292 108 2437 87150 477 476 11709 230461 8045 3636 40268 576 4252 4897 235336 108] dst: [ 2 213606 477 1455 235290 3510 748 8268 191017 2809 581 2032 69972 581 11495 1305 533 235303 65978 1654 1]
Ndërtoni një ngarkues të dhënash për të gjithë grupin e të dhënave MTNT:
@chex.dataclass(frozen=True)
class TrainingInput:
# Input tokens provided to the model.
input_tokens: jax.Array
# A mask that determines which tokens contribute to the target loss
# calculation.
target_mask: jax.Array
class DatasetSplit(enum.Enum):
TRAIN = 'train'
VALIDATION = 'valid'
class MTNTDatasetBuilder:
"""A data loader for the MTNT dataset."""
N_ITEMS = {DatasetSplit.TRAIN: 35_692, DatasetSplit.VALIDATION: 811}
BUFFER_SIZE_SHUFFLE = 10_000
TRANSLATION_PREFIX = 'Translate this into French:\n'
TRANSLATION_SUFFIX = '\n'
def __init__(self,
tokenizer : GriffinTokenizer,
max_seq_len: int):
"""A constructor.
Args:
tokenizer: The tokenizer to use.
max_seq_len: The size of each sequence in a given batch.
"""
self._tokenizer = tokenizer
self._base_data = {
DatasetSplit.TRAIN: tfds.load("mtnt/en-fr",split="train"),
DatasetSplit.VALIDATION: tfds.load("mtnt/en-fr",split="valid"),
}
self._max_seq_len = max_seq_len
def _tokenize_source(self, example: tf.Tensor):
"""A tokenization function for the source."""
return self._tokenizer.tokenize_tf_op(
example, prefix=self.TRANSLATION_PREFIX, suffix=self.TRANSLATION_SUFFIX,
add_eos=False
)
def _tokenize_destination(self, example: tf.Tensor):
"""A tokenization function for the French translation."""
return self._tokenizer.tokenize_tf_op(example, add_eos=True)
def _pad_up_to_max_len(self,
input_tensor: tf.Tensor,
pad_value: int | bool,
) -> tf.Tensor:
"""Pad the given tensor up to sequence length of a batch."""
seq_len = tf.shape(input_tensor)[0]
to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
return tf.pad(
input_tensor, [[0, to_pad]], mode='CONSTANT', constant_values=pad_value,
)
def _to_training_input(
self,
src_tokens: jax.Array,
dst_tokens: jax.Array,
) -> TrainingInput:
"""Build a training input from a tuple of source and destination tokens."""
# The input sequence fed to the model is simply the concatenation of the
# source and the destination.
tokens = tf.concat([src_tokens, dst_tokens], axis=0)
# You want to prevent the model from updating based on the source (input)
# tokens. To achieve this, add a target mask to each input.
q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
mask = tf.concat([q_mask, a_mask], axis=0)
# If the output tokens sequence is smaller than the target sequence size,
# then pad it with pad tokens.
tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)
# You don't want to perform the backward on the pad tokens.
mask = self._pad_up_to_max_len(mask, False)
return TrainingInput(input_tokens=tokens, target_mask=mask)
def get_train_dataset(self, batch_size: int, num_epochs: int):
"""Build the training dataset."""
# Tokenize each sample.
ds = self._base_data[DatasetSplit.TRAIN].map(
lambda x : (self._tokenize_source(x['src']),
self._tokenize_destination(x['dst']))
)
# Convert them to training inputs.
ds = ds.map(lambda x, y: self._to_training_input(x, y))
# Remove the samples which are too long.
ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
# Shuffle the dataset.
ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)
# Repeat if necessary.
ds = ds.repeat(num_epochs)
# Build batches.
ds = ds.batch(batch_size, drop_remainder=True)
return ds
def get_validation_dataset(self, batch_size: int):
"""Build the validation dataset."""
# Same as the training dataset, but no shuffling and no repetition
ds = self._base_data[DatasetSplit.VALIDATION].map(
lambda x : (self._tokenize_source(x['src']),
self._tokenize_destination(x['dst']))
)
ds = ds.map(lambda x, y: self._to_training_input(x, y))
ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
ds = ds.batch(batch_size, drop_remainder=True)
return ds
Provoni MTNTDatasetBuilder
duke instancuar përsëri GriffinTokenizer
in e personalizuar, më pas duke e aplikuar atë në grupin e të dhënave MTNT dhe duke marrë dy shembuj:
dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)
ds = dataset_builder.get_train_dataset(3, 1)
ds = ds.take(2)
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
print(f'Example {idx}:')
for key, val in example.items():
print(f'{key}: {val}')
print()
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class '__main__.TrainingInput'> Example 0: input_tokens: [[ 2 49688 736 1280 6987 235292 108 12583 665 235265 108 2 6151 94975 1320 6238 235265 1 0 0] [ 2 49688 736 1280 6987 235292 108 4899 29960 11270 108282 235265 108 2 4899 79025 11270 108282 1 0] [ 2 49688 736 1280 6987 235292 108 26620 235265 108 2 26620 235265 1 0 0 0 0 0 0]] target_mask: [[False False False False False False False False False False False True True True True True True True False False] [False False False False False False False False False False False False False True True True True True True False] [False False False False False False False False False False True True True True False False False False False False]] Example 1: input_tokens: [[ 2 49688 736 1280 6987 235292 108 527 5174 1683 235336 108 2 206790 581 20726 482 2208 1654 1] [ 2 49688 736 1280 6987 235292 108 28484 235256 235336 108 2 120500 13832 1654 1 0 0 0 0] [ 2 49688 736 1280 6987 235292 108 235324 235304 2705 235265 108 2 235324 235304 19963 235265 1 0 0]] target_mask: [[False False False False False False False False False False False False True True True True True True True True] [False False False False False False False False False False False True True True True True False False False False] [False False False False False False False False False False False False True True True True True True False False]]
Konfiguro modelin
Para se të filloni të rregulloni modelin Gemma, duhet ta konfiguroni atë.
Ngarkoni pikën e kontrollit të modelit RecurrentGemma (Griffin) me metodën recurrentgemma.jax.utils.load_parameters
:
params = recurrentgemma.load_parameters(CKPT_PATH, "single_device")
Për të ngarkuar automatikisht konfigurimin e saktë nga pika e kontrollit të modelit RecurrentGemma, përdorni recurrentgemma.GriffinConfig.from_flax_params_or_variables
:
config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(params)
Instantoni modelin Griffin me recurrentgemma.jax.Griffin
:
model = recurrentgemma.Griffin(config)
Krijoni një sampler
me recurrentgemma.jax.Sampler
në krye të pikës/peshave të modelit RecurrentGemma dhe shënuesin për të kontrolluar nëse modeli juaj mund të kryejë përkthim:
sampler = recurrentgemma.Sampler(model=model, vocab=vocab, params=params)
Rregulloni mirë modelin
Në këtë seksion, ju do të:
- Përdorni klasën
gemma.transformer.Transformer
për të krijuar funksionin e kalimit përpara dhe humbjes. - Ndërtoni vektorët e maskës së pozicionit dhe vëmendjes për argumentet
- Ndërtoni një funksion hapi trajnimi me Lirin.
- Ndërtoni hapin e vlefshmërisë pa kalimin prapa.
- Krijo ciklin e trajnimit.
- Rregulloni mirë modelin Gemma.
Përcaktoni kalimin përpara dhe funksionin e humbjes duke përdorur klasën recurrentgemma.jax.griffin.Griffin
. RecurrentGemma Griffin
trashëgon nga flax.linen.Module
dhe ofron dy metoda thelbësore:
-
init
: Inicializon parametrat e modelit. -
apply
: Ekzekuton funksionin__call__
të modelit duke përdorur një grup të caktuar parametrash.
Meqenëse jeni duke punuar me pesha Gemma të trajnuara paraprakisht, nuk keni nevojë të përdorni funksionin init
.
def forward_and_loss_fn(
params,
*,
model: recurrentgemma.Griffin,
input_tokens: jax.Array, # Shape [B, L]
input_mask: jax.Array, # Shape [B, L]
positions: jax.Array, # Shape [B, L]
) -> jax.Array:
"""Forward pass and loss function.
Args:
params: model's input parameters.
model: Griffin model to call.
input_tokens: input tokens sequence, shape [B, L].
input_mask: tokens to ignore when computing the loss, shape [B, L].
positions: relative position of each token, shape [B, L].
Returns:
Softmax cross-entropy loss for the next-token prediction task.
"""
batch_size = input_tokens.shape[0]
# Forward pass on the input data.
# No attention cache is needed here.
# Exclude the last step as it does not appear in the targets.
logits, _ = model.apply(
{"params": params},
tokens=input_tokens[:, :-1],
segment_pos=positions[:, :-1],
cache=None,
)
# Similarly, the first token cannot be predicteds.
target_tokens = input_tokens[:, 1:]
target_mask = input_mask[:, 1:]
# Convert the target labels into one-hot encoded vectors.
one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])
# Don't update on unwanted tokens.
one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]
# Normalization factor.
norm_factor = batch_size * (jnp.sum(target_mask) + 1e-8)
# Return the negative log-likelihood loss (NLL) function.
return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) / norm_factor
Ndërtoni funksionin train_step
që kryen kalimin prapa dhe përditëson parametrat e modelit në përputhje me rrethanat, ku:
-
jax.value_and_grad
është për vlerësimin e funksionit të humbjes dhe gradientëve gjatë kalimeve përpara dhe prapa. -
optax.apply_updates
është për përditësimin e parametrave.
Params = Mapping[str, Any]
def get_positions(example: jax.Array, pad_id : int) -> jax.Array:
"""Builds the position vector from the given tokens."""
pad_mask = example != pad_id
positions = jnp.cumsum(pad_mask, axis=-1)
# Subtract one for all positions from the first valid one as they are
# 0-indexed
positions = positions - (positions >= 1)
return positions
@functools.partial(
jax.jit,
static_argnames=['model', 'optimizer'],
donate_argnames=['params', 'opt_state'],
)
def train_step(
model: recurrentgemma.Griffin,
params: Params,
optimizer: optax.GradientTransformation,
opt_state: optax.OptState,
pad_id: int,
example: TrainingInput,
) -> tuple[jax.Array, Params, optax.OptState]:
"""The train step.
Args:
model: The RecurrentGemma (Griffin) model.
params: The model's input parameters.
optimizer: The Optax optimizer to use.
opt_state: The input optimizer's state.
pad_id: The ID of the pad token.
example: The input batch.
Returns:
Training loss, updated parameters, updated optimizer state.
"""
positions = get_positions(example.input_tokens, pad_id)
# Forward and backward passes.
train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(
params,
model=model,
input_tokens=example.input_tokens,
input_mask=example.target_mask,
positions=positions,
)
# Update the parameters.
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
return train_loss, params, opt_state
Ndërtoni funksionin validation_step
pa kalimin prapa:
@functools.partial(jax.jit, static_argnames=['model'])
def validation_step(
model: recurrentgemma.Griffin,
params: Params,
pad_id: int,
example: TrainingInput,
) -> jax.Array:
return forward_and_loss_fn(
params,
model=model,
input_tokens=example.input_tokens,
input_mask=example.target_mask,
positions=get_positions(example.input_tokens, pad_id),
)
Përcaktoni ciklin e trajnimit:
def train_loop(
model: recurrentgemma.Griffin,
params: Params,
optimizer: optax.GradientTransformation,
train_ds: Iterator[TrainingInput],
validation_ds: Iterator[TrainingInput],
num_steps: int | None = None,
eval_every_n: int = 20,
):
opt_state = jax.jit(optimizer.init)(params)
step_counter = 0
avg_loss=0
# The first round of the validation loss.
n_steps_eval = 0
eval_loss = 0
for val_example in validation_ds.as_numpy_iterator():
eval_loss += validation_step(
model, params, dataset_builder._tokenizer.pad_id, val_example
)
n_steps_eval += 1
print(f"Start, validation loss: {eval_loss/n_steps_eval}")
for train_example in train_ds:
train_loss, params, opt_state = train_step(
model=model,
params=params,
optimizer=optimizer,
opt_state=opt_state,
pad_id=dataset_builder._tokenizer.pad_id,
example=train_example,
)
step_counter += 1
avg_loss += train_loss
if step_counter % eval_every_n == 0:
eval_loss = 0
n_steps_eval = 0
val_iterator = validation_ds.as_numpy_iterator()
for val_example in val_iterator:
eval_loss += validation_step(
model,
params,
dataset_builder._tokenizer.pad_id,
val_example,
)
n_steps_eval +=1
avg_loss /= eval_every_n
eval_loss /= n_steps_eval
print(f"STEP {step_counter} training loss: {avg_loss} - eval loss: {eval_loss}")
avg_loss=0
if num_steps is not None and step_counter > num_steps:
break
return params
Këtu ju duhet të zgjidhni një optimizues (Optax). Për pajisjet me memorie më të vogël, duhet të përdorni SGD, pasi ka një gjurmë shumë më të ulët memorie. Për të arritur performancën më të mirë të akordimit, provoni Adam-W. Hiperparametrat optimale për çdo optimizues për detyrën e caktuar në këtë fletore janë dhënë në këtë shembull për pikën e kontrollit 2b-it
.
def griffin_weight_decay_mask(params_like: optax.Params) -> Any:
# Don't put weight decay on the RGLRU, the embeddings and any biases
def enable_weight_decay(path: list[Any], _: Any) -> bool:
# Parameters in the LRU and embedder
path = [dict_key.key for dict_key in path]
if 'rg_lru' in path or 'embedder' in path:
return False
# All biases and scales
if path[-1] in ('b', 'scale'):
return False
return True
return jax.tree_util.tree_map_with_path(enable_weight_decay, params_like)
optimizer_choice = "sgd"
if optimizer_choice == "sgd":
optimizer = optax.sgd(learning_rate=1e-3)
num_steps = 300
elif optimizer_choice == "adamw":
optimizer = optax.adamw(
learning_rate=1e-4,
b2=0.96,
eps=1e-8,
weight_decay=0.1,
mask=griffin_weight_decay_mask,
)
num_steps = 100
else:
raise ValueError(f"Unknown optimizer: {optimizer_choice}")
Përgatitni grupet e të dhënave të trajnimit dhe vërtetimit:
# Choose a small sequence length size, so that everything fits in memory.
num_epochs = 1
batch_size = 1
sequence_length = 32
# Make the dataset builder.
tokenizer = GriffinTokenizer(vocab)
dataset_builder= MTNTDatasetBuilder(tokenizer, sequence_length + 1)
# Build the training dataset.
train_ds = dataset_builder.get_train_dataset(
batch_size=batch_size,
num_epochs=num_epochs,
).as_numpy_iterator()
# Build the validation dataset, with a limited number of samples for this demo.
validation_ds = dataset_builder.get_validation_dataset(
batch_size=batch_size,
).take(50)
Filloni të rregulloni modelin RecurrentGemma (Griffin) në një numër të kufizuar hapash ( num_steps
):
trained_params = train_loop(
model=model,
params=params,
optimizer=optimizer,
train_ds=train_ds,
validation_ds=validation_ds,
num_steps=num_steps,
)
Start, validation loss: 7.894117832183838 /usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,33]), ShapedArray(bool[1,33]), ShapedArray(int32[], weak_type=True). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" STEP 20 training loss: 4.592616081237793 - eval loss: 2.847407102584839 STEP 40 training loss: 2.7537424564361572 - eval loss: 2.9258534908294678 STEP 60 training loss: 2.835618257522583 - eval loss: 2.4382340908050537 STEP 80 training loss: 2.6322107315063477 - eval loss: 2.3696839809417725 STEP 100 training loss: 1.8703256845474243 - eval loss: 2.355681896209717 STEP 120 training loss: 2.7280433177948 - eval loss: 2.4059958457946777 STEP 140 training loss: 2.3047447204589844 - eval loss: 2.083082914352417 STEP 160 training loss: 2.3432137966156006 - eval loss: 2.095074415206909 STEP 180 training loss: 2.1081202030181885 - eval loss: 2.006460189819336 STEP 200 training loss: 2.5359647274017334 - eval loss: 1.9667452573776245 STEP 220 training loss: 2.202195644378662 - eval loss: 1.9440618753433228 STEP 240 training loss: 2.756615400314331 - eval loss: 2.1073737144470215 STEP 260 training loss: 2.5128934383392334 - eval loss: 2.117241859436035 STEP 280 training loss: 2.73045015335083 - eval loss: 1.9159646034240723 STEP 300 training loss: 2.0918595790863037 - eval loss: 1.9742532968521118
Si humbja e trajnimit ashtu edhe humbja e vlefshmërisë duhet të ishin ulur me numërimin e çdo hapi.
Për t'u siguruar që të dhënat tuaja përputhen me formatin e trajnimit, mos harroni të përdorni prefiksin Translate this into French:\n
dhe një karakter të linjës së re në fund. Kjo sinjalizon modelin që të fillojë përkthimin.
sampler.params = trained_params
output = sampler(
["Translate this into French:\nHello, my name is Morgane.\n"],
total_generation_steps=100,
)
print(output.text[0])
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,16]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" Mais je m'appelle Morgane.
Mësoni më shumë
- Mund të mësoni më shumë rreth bibliotekës së Google DeepMind
recurrentgemma
në GitHub , e cila përmban vargjet e metodave dhe moduleve që keni përdorur në këtë tutorial, si p.sh.recurrentgemma.jax.load_parameters
,recurrentgemma.jax.Griffin
dherecurrentgemma.jax.Sampler
. - Bibliotekat e mëposhtme kanë faqet e tyre të dokumentacionit: core JAX , Flax , Chex , Optax dhe Orbax .
- Për dokumentacionin e detokenizuesit/detokenizuesit të
sentencepiece
, shikoni deponimin e GitHub tësentencepiece
të Google . - Për dokumentacionin
kagglehub
, shikoniREADME.md
në repon ekagglehub
GitHub të Kaggle . - Mësoni se si të përdorni modelet Gemma me Google Cloud Vertex AI .
- Nëse jeni duke përdorur TPU të Google Cloud (v3-8 dhe më të reja), sigurohuni që gjithashtu të përditësoni me paketën më të fundit
jax[tpu]
(!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
jaxlib
, rinisni versionin ejax
dhe kontrolloni!pip list | grep jax
). Kjo mund të parandalojëRuntimeError
që mund të lindë për shkak të mospërputhjes së versionitjaxlib
dhejax
. Për më shumë udhëzime për instalimin e JAX, referojuni dokumenteve JAX . - Shikoni letrën RecurrentGemma: Moving Past Transformers for Efficient Open Language Models nga Google DeepMind.
- Lexoni letrën Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models nga Google DeepMind për të mësuar më shumë rreth arkitekturës së modelit të përdorur nga RecurrentGemma.